#!/usr/bin/env python
# -*- coding: utf-8; py-indent-offset:4 -*-
###############################################################################
#
# Copyright (C) 2015, 2016, 2017 Daniel Rodriguez
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.
#
###############################################################################
from __future__ import (absolute_import, division, print_function,
                        unicode_literals)

import collections
import json
from .ccxtorder import CCXTOrder, Okex5CCXTOrder
from backtrader import BrokerBase, Order
from backtrader.position import CCXTPosition
from backtrader.utils.py3 import queue, with_metaclass
from backtrader.stores.okex5store import Okex5Store
from backtrader.utils import logger


class MetaCCXTBroker(BrokerBase.__class__):
    '''Metaclass that registers every class it creates as ``Okex5Store.BrokerCls``.'''

    def __init__(cls, name, bases, dct):
        '''Finish initializing the new class, then register it with the store.'''
        super().__init__(name, bases, dct)
        # The class object already exists here; only the registration is left
        Okex5Store.BrokerCls = cls

"""
1. 普通订单接口
    - 单个下单
    - 批量下单
    - 单个查询
    - 未成交列表
2. 止盈止损订单接口
    - 单个下单（返回algoid）, 当触发价触发的时候, 系统自动提交普通限价单到撮合引擎, 接口返回带有orderid和state=effective[生效],
        这个时候未成交列表就查不到，只能去成交列表查看, 然后查看普通委托单是否成交成功
    - 未成交列表(未成交的时候返回orderid为0), 当成交之后, 调用这个接口查不到, 只能查询已成交
    - 已成交列表（成交成功的时候返回orderid，并且能在普通订单接口调用查询）
"""

'''
Any newly added exchange must name its broker class following this pattern.
'''
class Okex5Broker(with_metaclass(MetaCCXTBroker, BrokerBase)):
    '''Broker implementation for CCXT cryptocurrency trading library.
    This class maps the orders/positions from CCXT to the
    internal API of ``backtrader``.

    Broker mapping added as I noticed that there are differences between the
    expected order_types and the returned statuses from canceling an order.

    Added a new mappings parameter to the script with defaults.

    Added a get_balance function. Manually check the account balance and update brokers
    self.cash and self.value. This helps alleviate rate limit issues.

    Added a new get_wallet_balance method. This will allow manual checking of the any coins
        The method will allow setting parameters. Useful for dealing with multiple assets

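    Modified getcash() and getvalue():
        Backtrader will call getcash and getvalue before and after next, slowing things down
        with rest calls. As such, they work from the locally cached cash and positions
        (refreshed by flush_position) instead of querying the exchange on every call.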

    The broker mapping should contain a new dict for order_types and mappings like below:

    broker_mapping = {
        'order_types': {
            bt.Order.Market: 'market',
            bt.Order.Limit: 'limit',
            bt.Order.Stop: 'stop-loss', #stop-loss for kraken, stop for bitmex
            bt.Order.StopLimit: 'stop limit'
        },
        'mappings':{
            'closed_order':{
                'key': 'status',
                'value':'closed'
                },
            'canceled_order':{
                'key': 'result',
                'value':1}
                }
        }

    Added new private_end_point method to allow using any private non-unified end point

    '''

    # Default translation from backtrader execution types to the order-type
    # strings handed to the store; replaced per instance when ``broker_mapping``
    # contains an 'order_types' entry.
    order_types = {Order.Market: 'market',
                   Order.Limit: 'limit',
                   Order.Stop: 'stop',  # stop-loss for kraken, stop for bitmex
                   Order.StopLimit: 'stop limit'}

    # Default rules used to recognize order states in exchange responses: the
    # response field named by 'key' is compared against 'value'. Replaced per
    # instance when ``broker_mapping`` contains a 'mappings' entry.
    # NOTE(review): the methods of this class also read 'order_result',
    # 'order_strat_result', 'closed_strat_order' and 'canceled_strat_order'
    # from the mappings; these defaults do not define them, so presumably a
    # complete ``broker_mapping`` has to be supplied -- confirm.
    mappings = {
        'closed_order': {
            'key': 'status',
            'value': 'closed'
        },
        'canceled_order': {
            'key': 'status',
            'value': 'canceled'}
    }

    def __init__(self, broker_mapping=None, debug=False, **kwargs):
        '''Set up the broker state and take a first snapshot of the wallet.

        Params:
          - ``broker_mapping``: optional dict; its ``'order_types'`` and
            ``'mappings'`` entries, when present, replace the class-level
            defaults on this instance
          - ``debug``: enables the extra debug output of ``next`` and of the
            cancel methods
          - ``kwargs``: forwarded unchanged to ``Okex5Store``
        '''
        super(Okex5Broker, self).__init__()

        # Either entry may be left out to keep the corresponding default
        if broker_mapping is not None:
            for attr in ('order_types', 'mappings'):
                try:
                    override = broker_mapping[attr]
                except KeyError:
                    continue
                setattr(self, attr, override)

        self.store = Okex5Store(**kwargs)
        self.positions = collections.defaultdict(CCXTPosition)

        # Cached account figures, refreshed by flush_position / getvalue
        self.cash = 0
        self.value = 0

        self.debug = debug
        self.indent = 4  # indentation used when pretty printing dictionaries

        # Orders waiting to be handed out through get_notification
        self.notifs = queue.Queue()

        # Orders that next() keeps polling
        self.open_orders = list()

        self.logger = logger.getLogger(__name__)

        self.startingcash = 0
        self.startingvalue = 0

        # Initial load of the positions at start-up
        self.flush_position()

    def flush_position(self, params={}):
        """
        binance交易所currency是btc的base，而okex是usdt，所以币安交易所的value其实是这个币的size
        """
        balance = self.store.get_wallet(params=params)
        self.logger.info("----------- Position -------------------")
        for currency in self.store.currency:
            free_size = balance[currency]["free"] if currency in balance else 0.0
            locked_size = balance[currency]["used"] if currency in balance else 0.0

            if currency in self.positions:
                self.positions[currency].fix(free_size, locked_size, 0)
            else:
                self.positions[currency] = CCXTPosition(currency, free_size, locked_size, "SPOT", 1, 0)
            self.logger.info("%s: free=%.10f, locked=%.10f", currency, free_size, locked_size)

        self.cash = self.positions["USDT"].free_size

    def load_opens_order(self, owner):
        '''Rebuild ``self.open_orders`` from the orders still open on the exchange.

        Regular open orders and open strategy (algo) orders are fetched from
        the store, restricted to the symbols of ``owner.datas`` and wrapped
        into order objects:

          - an entry containing ``"algoId"`` becomes an ``Okex5CCXTOrder``
          - a regular order with a client order id is looked up through
            ``store.search_strat_order``; when a strategy order is found it is
            wrapped as an ``Okex5CCXTOrder`` and marked as triggered with the
            regular order as its child
          - any other regular order becomes a ``CCXTOrder``

        Params:
          - ``owner``: object exposing ``datas``; used as the owner of every
            created order
        '''
        feeds = {d.p.dataname: d for d in owner.datas}

        # Open orders, limited to the symbols that have a data feed. A ``None``
        # answer from the store is treated as "no orders".
        normal_orders = [
            o for o in (self.store.fetch_open_orders() or [])
            if o[self.mappings["order_result"]["symbol"]] in feeds
        ]
        strat_orders = [
            o for o in (self.store.fetch_open_strat_orders(symbol=None, order_type=Order.StopLimit) or [])
            if o[self.mappings["order_strat_result"]["symbol"]] in feeds
        ]

        for order in normal_orders + strat_orders:
            if "algoId" in order:
                symbol = order[self.mappings["order_strat_result"]["symbol"]]
                self.open_orders.append(
                    Okex5CCXTOrder(owner, feeds[symbol], order, {}, self.mappings, self.order_types))
                continue

            result_keys = self.mappings["order_result"]
            symbol = order[result_keys["symbol"]]
            plain_order = CCXTOrder(owner, feeds[symbol], order, {}, self.mappings, self.order_types)

            stop_ccxt_order = None
            if order[result_keys["clientOrderId"]]:
                # A client order id suggests a limit order spawned by a
                # triggered take-profit / stop-loss order
                ord_id = order[result_keys["id"]]
                stop_ccxt_order = self.store.search_strat_order(symbol, ord_id)
                if not stop_ccxt_order:
                    self.logger.info(f"找到一个触发止损止盈单, 但是反过来查它的策略委托单的时候查不到, order_id={ord_id}")

            if not stop_ccxt_order:
                # No parent strategy order (or none could be found): keep
                # tracking the live order as a regular one instead of failing
                # on the missing lookup result
                self.open_orders.append(plain_order)
                continue

            ccxt_order = Okex5CCXTOrder(owner, feeds[symbol], stop_ccxt_order, {}, self.mappings, self.order_types)
            trigger_time = stop_ccxt_order[self.mappings['order_strat_result']['trade_dt']]
            # The child handed to trigger() is the regular order, mirroring
            # handle_strat_order, which later polls ``ref_order.id`` as a
            # regular order id.
            # NOTE(review): Okex5CCXTOrder.trigger is not visible here --
            # confirm it stores this argument as ``ref_order``.
            ccxt_order.trigger(trigger_time, plain_order)
            self.open_orders.append(ccxt_order)


    """
    合约持仓用这个来判断(目前适配的是okex)
    """
    def fetch_position(self, symbol, posSide="long"):
        return self.store.fetch_position(symbol=symbol, posSide=posSide)

    def getposition(self, data):
        '''Returns the current position status (a ``Position`` instance) for
        the given ``data``'''
        currency = data.p.dataname.replace("/USDT", "")
        return self.positions[currency]

    def getposvalue(self, data):
        position = self.getposition(data)
        return position.total_size * data.close[0]

    def fetch_market(self, symbol):
        return self.store.fetch_market(symbol=symbol)

    def getcash(self):
        # Get cash seems to always be called before get value
        # Therefore it makes sense to add getbalance here.
        # return self.store.getcash(self.currency)
        return self.cash

    def getvalue(self, datas=None):
        # return self.store.getvalue(self.currency)
        self.value = 0
        if datas:
            for _data in datas:
                currency = _data.p.dataname.replace("/USDT", "")
                if currency in self.positions:
                    self.value += self.positions[currency].total_size * _data.close[0]
            self.value += self.positions["USDT"].total_size * 1
        return self.value

    def get_notification(self):
        try:
            return self.notifs.get(False)
        except queue.Empty:
            return None

    def notify(self, order):
        self.notifs.put(order.clone())

    def handle_order(self, o_order):
        oID = o_order.ccxt_order[self.mappings["order_result"]["id"]]
        # Get the order
        ccxt_order = self.store.fetch_order(oID, o_order.data.p.dataname)
        # self.logger.info("fetch_order: %s, %s", oID, o_order.data.p.dataname)

        # Check if the order is closed
        if ccxt_order[self.mappings['closed_order']['key']] == self.mappings['closed_order']['value']:
            dt = ccxt_order[self.mappings["order_result"]["trade_dt"]]
            amount = ccxt_order[self.mappings["order_result"]["size"]]
            trade_price = ccxt_order[self.mappings["order_result"]["trade_price"]]
            o_order.execute(dt, amount, trade_price, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0)
            o_order.completed()
            self.logger.info('Completed Order ID: %s', oID)
            self.notify(o_order)
            self.open_orders.remove(o_order)
            return True
        return False

    def handle_strat_order(self, o_order):
        '''Poll a strategy (take-profit / stop-loss) order and advance its state.

        An order already in ``Order.Trigger`` status has its child order
        (``o_order.ref_order``) fetched again. Otherwise the strategy order is
        fetched; if it does not match the ``closed_strat_order`` rule nothing
        happens, else the regular order it spawned is fetched, wrapped in a
        ``CCXTOrder``, attached through ``o_order.trigger`` and notified.

        In both cases, once the child order matches the ``closed_order`` rule
        the child and the strategy order are executed, the strategy order is
        completed, removed from ``self.open_orders`` and notified.

        Returns ``True`` only when the strategy order was completed.
        '''
        oID = o_order.ccxt_order[self.mappings["order_strat_result"]["id"]]
        symbol = o_order.data.p.dataname

        if o_order.status == Order.Trigger:
            # Already triggered earlier: only the child order needs polling
            child_order = self.store.fetch_order(o_order.ref_order.id, symbol)
            child_ccxt_order = o_order.ref_order
        else:
            ccxt_order = self.store.fetch_strat_order(symbol, o_order.exectype, oID)
            strat_closed = self.mappings['closed_strat_order']
            if ccxt_order[strat_closed['key']] != strat_closed['value']:
                return False

            strat_keys = self.mappings['order_strat_result']
            order_id = ccxt_order[strat_keys['orderid']]
            trigger_time = ccxt_order[strat_keys['trade_dt']]
            child_order = self.store.fetch_order(order_id, symbol)
            child_ccxt_order = CCXTOrder(o_order.owner, o_order.data, child_order, None,
                                         self.mappings, self.order_types)
            o_order.trigger(trigger_time, child_ccxt_order)
            self.notify(o_order)

        closed_rule = self.mappings['closed_order']
        if child_order[closed_rule['key']] != closed_rule['value']:
            return False

        result_keys = self.mappings["order_result"]
        dt = child_order[result_keys["trade_dt"]]
        amount = child_order[result_keys["size"]]
        trade_price = child_order[result_keys["trade_price"]]

        # The ten remaining execute() arguments are all reported as 0
        zeros = (0,) * 10
        child_ccxt_order.execute(dt, amount, trade_price, *zeros)
        o_order.execute(dt, amount, trade_price, *zeros)
        o_order.completed()
        self.open_orders.remove(o_order)
        self.notify(o_order)
        return True

    def next(self):
        '''Poll every tracked open order once and refresh positions if needed.

        Orders whose execution type is one of the stop / profit types are
        handled by ``handle_strat_order``, all others by ``handle_order``.
        ``flush_position`` is called once afterwards when at least one handler
        reported a completed order.
        '''
        if self.debug:
            self.logger.debug('Broker next() called')

        strat_types = (Order.Stop, Order.StopLimit, Order.Profit,
                       Order.ProfitLimit, Order.StopProfit, Order.StopProfitLimit)

        update_flag = False
        # Iterate over a copy: the handlers remove finished orders from the list
        for o_order in list(self.open_orders):
            if o_order.exectype in strat_types:
                # Stop-loss / take-profit style orders
                changed = self.handle_strat_order(o_order)
            else:
                changed = self.handle_order(o_order)
            # Accumulate instead of overwrite, so a completion is not lost
            # when a later order in the list reports no change
            update_flag = update_flag or changed

        if update_flag:
            self.flush_position()

    def _submit_strat(self, owner, data, exectype, side, amount, price, plimit, params):
        '''Create a strategy (take-profit / stop-loss) order and start tracking it.

        The order is created through the store with ``params`` expanded as
        keyword arguments, fetched back by the returned id and wrapped in an
        ``Okex5CCXTOrder``. The order is then tracked and accepted, positions
        are refreshed and the order is notified. ``price`` is not used here.

        Returns the new ``Okex5CCXTOrder``.
        '''
        symbol = data.p.dataname
        placed = self.store.create_strat_order(symbol, exectype, side, amount, **params)
        algo_id = placed[self.mappings["order_strat_result"]["id"]]
        exchange_order = self.store.fetch_strat_order(symbol, exectype, algo_id)

        order = Okex5CCXTOrder(owner, data, exchange_order, params, self.mappings, self.order_types)

        # plimit can be read as the trigger price: for buy/sell orders it is
        # the signal price (without slippage), for stop orders the stop
        # trigger price. Only fill it in when the order has none.
        if not order.plimit:
            if plimit:
                order.plimit = plimit

        self.open_orders.append(order)
        order.accept(self)
        self.flush_position()
        self.notify(order)
        return order

    def _submit_order(self, owner, data, exectype, side, amount, price, plimit, params):
        '''Create a regular order on the exchange and start tracking it.

        The order type is looked up in ``self.order_types`` (``'market'`` when
        ``exectype`` is falsy). The order is created through the store,
        fetched back by the returned ``'id'`` and wrapped in a ``CCXTOrder``,
        which is tracked and accepted before positions are refreshed.

        A notification is queued only while the fetched order does not match
        the ``closed_order`` rule; an order that is already closed is notified
        by ``next`` instead, which avoids notifying it twice.

        Params:
          - ``params``: extra arguments for the store; when it has a
            ``'params'`` key only that nested value is passed on

        Returns the new ``CCXTOrder``.
        '''
        order_type = self.order_types.get(exectype) if exectype else 'market'
        # Extract the CCXT specific params if they were nested by the caller
        params = params['params'] if 'params' in params else params

        ret_ord = self.store.create_order(symbol=data.p.dataname, order_type=order_type, side=side,
                                          amount=amount, price=price, params=params)

        _order = self.store.fetch_order(ret_ord['id'], data.p.dataname)

        order = CCXTOrder(owner, data, _order, params, self.mappings, self.order_types)

        # plimit can be read as the trigger price: for buy/sell orders it is
        # the signal price (without slippage), for stop orders the stop
        # trigger price. Only fill it in when the order has none.
        if not order.plimit and plimit:
            order.plimit = plimit

        self.open_orders.append(order)
        order.accept(self)

        # Submitting a buy or a sell locks assets, which changes the free and
        # used amounts considerably, so the account is refreshed
        self.flush_position()

        # A market order is usually closed already at this point and will be
        # notified by next(); only submitted-but-unfilled orders are notified
        # here
        closed_rule = self.mappings['closed_order']
        if order.ccxt_order[closed_rule['key']] != closed_rule['value']:
            self.notify(order)
        return order

    def _submit(self, owner, data, exectype, side, size, price, plimit, params):
        '''Dispatch a new order to ``_submit_strat`` for the stop / profit
        execution types listed below, and to ``_submit_order`` otherwise.'''
        strat_types = (Order.Stop, Order.StopLimit, Order.Profit,
                       Order.ProfitLimit, Order.StopProfitLimit)
        submit = self._submit_strat if exectype in strat_types else self._submit_order
        return submit(owner, data, exectype, side, size, price, plimit, params)

    def buy(self, owner, data, size, price=None, plimit=None,
            exectype=None, valid=None, tradeid=0, oco=None,
            trailamount=None, trailpercent=None,
            **kwargs):
        del kwargs['parent']
        del kwargs['transmit']
        return self._submit(owner, data, exectype, 'buy', size, price, plimit, kwargs)

    def sell(self, owner, data, size, price=None, plimit=None,
             exectype=None, valid=None, tradeid=0, oco=None,
             trailamount=None, trailpercent=None,
             **kwargs):
        del kwargs['parent']
        del kwargs['transmit']
        return self._submit(owner, data, exectype, 'sell', size, price, plimit, kwargs)

    def cancel_order(self, order):
        oID = order.ccxt_order[self.mappings["order_result"]["id"]]

        if self.debug:
            print('Broker cancel() called')
            print('Fetching Order ID: {}'.format(oID))

        # check first if the order has already been filled otherwise an error
        # might be raised if we try to cancel an order that is not open.
        ccxt_order = self.store.fetch_order(oID, order.data.p.dataname)

        if self.debug:
            print(json.dumps(ccxt_order, indent=self.indent))

        if ccxt_order[self.mappings['closed_order']['key']].lower() == self.mappings['closed_order']['value'].lower():
            return order

        ccxt_order = self.store.cancel_order(oID, order.data.p.dataname)
        if self.debug:
            print(json.dumps(ccxt_order, indent=self.indent))
            print('Value Received: {}'.format(ccxt_order[self.mappings['canceled_order']['key']]))
            print('Value Expected: {}'.format(self.mappings['canceled_order']['value']))

        # 重新获取下订单
        ccxt_order = self.store.fetch_order(oID, order.data.p.dataname)

        if ccxt_order[self.mappings['canceled_order']['key']].lower() == self.mappings['canceled_order']['value'].lower():
            self.open_orders.remove(order)
            order.cancel()
            self.flush_position()
            self.notify(order)
        return order

    def cancel_strat_order(self, order):
        '''Cancel a strategy (take-profit / stop-loss) order.

        Nothing is done for an order whose status is ``Order.Completed`` or
        whose strategy order already matches the ``canceled_strat_order``
        rule. If the strategy order matches the ``closed_strat_order`` rule
        (it has taken effect), the regular order it spawned is canceled
        instead, unless that order is already closed. Otherwise the strategy
        order itself is canceled through the store.

        After a confirmed cancellation, an order that is still tracked in
        ``self.open_orders`` is removed from there, marked canceled, positions
        are refreshed and the order is notified.

        Rule values are compared case-insensitively after conversion to
        ``str``, so non-string values work as well.

        Returns ``order`` in every case.
        '''
        def _matches(payload, rule_name):
            # True when the response field named by the rule equals its value
            rule = self.mappings[rule_name]
            return str(payload[rule['key']]).lower() == str(rule['value']).lower()

        def _finalize():
            # An untracked order was finalized before and must not be removed
            # or notified a second time
            if order in self.open_orders:
                self.open_orders.remove(order)
                order.cancel()
                self.flush_position()
                self.notify(order)

        symbol = order.data.p.dataname
        if order.status == Order.Completed:
            return order

        oID = order.ccxt_order[self.mappings["order_strat_result"]["id"]]
        order_type = order.exectype

        if self.debug:
            print('Broker cancel() called')
            print('Fetching Order ID: {}'.format(oID))

        # Check the current state first: trying to cancel an order that is
        # not open might raise an error
        ccxt_order = self.store.fetch_strat_order(symbol, order_type, oID)

        # The strategy order has already been canceled: return right away
        if _matches(ccxt_order, 'canceled_strat_order'):
            return order

        if _matches(ccxt_order, 'closed_strat_order'):
            # The strategy order has taken effect: act on the regular order
            order_id = ccxt_order[self.mappings['order_strat_result']['orderid']]
            ccxt_order_cld = self.store.fetch_order(order_id, symbol)

            if _matches(ccxt_order_cld, 'closed_order'):
                return order

            # Only send the cancel request while the child is not canceled yet
            if not _matches(ccxt_order_cld, 'canceled_order'):
                ccxt_order_cld = self.store.cancel_order(order_id, symbol)
                if self.debug:
                    print(json.dumps(ccxt_order_cld, indent=self.indent))
                    print('Value Received: {}'.format(ccxt_order_cld[self.mappings['canceled_order']['key']]))
                    print('Value Expected: {}'.format(self.mappings['canceled_order']['value']))

                ccxt_order_cld = self.store.fetch_order(order_id, symbol)

            if _matches(ccxt_order_cld, 'canceled_order'):
                _finalize()
        else:
            # The strategy order has not taken effect yet: cancel it directly
            ccxt_order = self.store.cancel_strat_order(symbol, oID)
            if _matches(ccxt_order, 'canceled_strat_order'):
                _finalize()
        return order

    def cancel(self, order):
        '''Cancel ``order`` through ``cancel_strat_order`` for the stop /
        profit execution types listed below, and through ``cancel_order``
        otherwise.'''
        strat_types = (Order.Stop, Order.StopLimit, Order.Profit,
                       Order.ProfitLimit, Order.StopProfitLimit)
        canceller = self.cancel_strat_order if order.exectype in strat_types else self.cancel_order
        return canceller(order)


    def private_end_point(self, type, endpoint, params):
        '''
        Open method to allow calls to be made to any private end point.
        See here: https://github.com/ccxt/ccxt/wiki/Manual#implicit-api-methods

        - type: String, 'Get', 'Post','Put' or 'Delete'.
        - endpoint = String containing the endpoint address eg. 'order/{id}/cancel'
        - Params: Dict: An implicit method takes a dictionary of parameters, sends
          the request to the exchange and returns an exchange-specific JSON
          result from the API as is, unparsed.

        To get a list of all available methods with an exchange instance,
        including implicit methods and unified methods you can simply do the
        following:

        print(dir(ccxt.hitbtc()))
        '''
        endpoint_str = endpoint.replace('/', '_')
        endpoint_str = endpoint_str.replace('{', '')
        endpoint_str = endpoint_str.replace('}', '')

        method_str = 'private_' + type.lower() + endpoint_str.lower()

        return self.store.private_end_point(type=type, endpoint=method_str, params=params)


